[JAX] Expert Parallelism: JAX primitives + VJPs#3036
Conversation
Greptile SummaryThis PR lands the JAX Expert Parallelism (EP) bindings: XLA FFI handlers wrapping the
Confidence Score: 4/5Safe to merge with one fix: the dispatch-backward partition function declares an output sharding with the wrong rank for grad_topk_weights, causing a JAX compile-time error in any multi-device training run that backpropagates through ep_dispatch. The dispatch-backward partition function specifies PartitionSpec(*resolved, None) for grad_topk_weights, producing a spec one rank wider than the tensor's actual shape. Under SPMD JIT with any mesh, JAX will reject the sharding at compile time. The bug is latent in all backward-training paths and easy to hit once EP is exercised in a real training loop. The rest of the primitive stack, the C++ FFI layer, and the custom_vjp math all look correct. transformer_engine/jax/cpp_extensions/ep.py — specifically the EpDispatchBwdPrimitive.partition method (lines 763-766). Important Files Changed
Sequence DiagramsequenceDiagram
participant PY as Python (ep.py)
participant Prim as JAX Primitives (cpp_extensions/ep.py)
participant FFI as XLA FFI (ep.cpp)
participant NCCL as NCCL EP (nvte_ep_*)
Note over PY: ep_bootstrap()
PY->>FFI: SetEpBootstrapParams(uid, ep_size, ...)
FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize
Note over PY: ep_dispatch() forward
PY->>Prim: ep_prepare(topk_idx)
Prim->>FFI: EpPrepareHandler
FFI->>NCCL: nvte_ep_prepare
NCCL-->>PY: token_counts, EpHandle
PY->>Prim: ep_dispatch_fwd(handle, tokens, topk_weights)
Prim->>FFI: EpDispatchHandler
FFI->>NCCL: nvte_ep_dispatch
NCCL-->>PY: recv_tokens, recv_topk_weights
Note over PY: Expert FFN runs on recv_tokens
PY->>Prim: ep_combine_fwd(handle, weighted_expert_out)
Prim->>FFI: EpCombineHandler
FFI->>NCCL: nvte_ep_combine
NCCL-->>PY: combined output
Note over PY: ep_dispatch() backward
PY->>Prim: ep_dispatch_bwd(handle, g_recv_tokens, g_recv_topk_weights)
Prim->>FFI: EpDispatchBwdHandler
FFI->>NCCL: nvte_ep_dispatch_bwd
NCCL-->>PY: grad_tokens, grad_topk_weights
PY->>Prim: ep_combine_bwd(handle, g_result)
Prim->>FFI: EpCombineBwdHandler
FFI->>NCCL: nvte_ep_combine_bwd
NCCL-->>PY: grad_expert_out
Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| } | ||
|
|
||
| private: | ||
| EpCommManager() = default; |
There was a problem hiding this comment.
If we use stateful FFI calls we could tie to EP communicator to the lifetime of the jax computation rather than the process.
There was a problem hiding this comment.
Cool to learn! I will update it.
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
There was a problem hiding this comment.
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
There was a problem hiding this comment.
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
|
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia! |
| kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) | ||
|
|
||
| @jax.jit | ||
| def step(idx, toks, w, lk): |
There was a problem hiding this comment.
What does lk stand for?
| leading = _ep_leading_dims(is_outer) | ||
| recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) | ||
| recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) | ||
| workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64) |
There was a problem hiding this comment.
Same comment as above about int64
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
| assert ret == 0, f"ncclGetUniqueId failed with code {ret}" | ||
| uid_bytes = bytes(uid_arr) |
There was a problem hiding this comment.
assert disabled by -O in ctypes UID path
assert ret == 0 is silently elided when Python runs under the -O optimisation flag (common in production or Numba/Conda environments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.
| assert ret == 0, f"ncclGetUniqueId failed with code {ret}" | |
| uid_bytes = bytes(uid_arr) | |
| ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) | |
| if ret != 0: | |
| raise RuntimeError(f"ncclGetUniqueId failed with code {ret}") |
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the
nvte_ep_*C API, a Python wrapper withcustom_vjpfor autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCLncclEpDispatch/ncclEpCombineare exposed as XLA primitives and work with CUDA-graph capture.Implementation
Public Python API (
transformer_engine/jax/ep.py)ep_dispatch/ep_combinearejax.custom_vjpfunctions: forward is the FFI primitive, backward calls the matchingnvte_ep_*_bwdFFI primitive directly (noep_preparein the bwd — routing state is already cached inhandle.mem). Note thatep_dispatchalso callsep_preparein the forward path, which all-gathers and preprocesses routing maps.XLA FFI bindings (
transformer_engine/jax/csrc/extensions/ep.cpp)Five
XLA_FFI_DEFINE_HANDLER_SYMBOLentries —EpPrepareHandler,EpDispatchHandler,EpCombineHandler,EpDispatchBwdHandler,EpCombineBwdHandler— each calling the correspondingnvte_ep_*C entry point. All markedFFI_CudaGraph_Traitsso they capture cleanly.handle_idis a static FFI attribute baked at jit trace time.Primitives + Python layer (
transformer_engine/jax/cpp_extensions/ep.py, +951 lines)Standard TE primitive plumbing:
abstract_eval(shape/dtype inference),lowering,impl,outer_primitiveregistration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).Sharding (
transformer_engine/jax/sharding.py, +12 lines)Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (
build_tools/jax.py, +41 lines)Threads NCCL EP linkage through the JAX
transformer_engine_jaxextension. No new top-level build flags; rides on the parent PR'sNVTE_BUILD_WITH_NCCL_EP.Tests & example
tests/jax/test_multi_process_ep.py(+690 lines): 13 tests covering bootstrap,ep_prepareshape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing),custom_vjpfwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).tests/jax/multi_process_launch_ep.sh: 4-rank launcher; setsXLA_FLAGSto keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).examples/jax/ep/ep_moe.py(+394 lines) +run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison--checkthat verifies fwd+bwd vs a single-process reference.Type of change
Checklist: